# Share对象防止缓存击穿-Python Asyncio实践

# 参考

# 前记

本文描述的是如何基于Asyncio.Future​的特性编写一个语言级别的防缓存击穿的工具–Share​,并介绍它的使用和高并发下的处理方法。

# 1.缓存击穿

在后端服务中,大部分的系统瓶颈都集中在DB上,为了提升服务性能和减轻DB压力,一般会添加一层缓存层,然后每个请求都会先从缓存层获取数据,如果获取不到数据则先从DB系统中获取数据,获取到数据后才把数据放置在缓存层中再返回数据。
不过缓存层缓存的数据是有过期时间的,当设置的值过期时会导致所有请求无法从缓存层获取到数据,进而蜂拥冲向DB系统获取数据,这就是缓存击穿。由于缓存击穿是瞬时性且量级很大,所以它容易快速的提升DB系统压力,甚至打崩DB系统,流程如图:

从图中可以看到,在缓存值失效后,所有请求会先后击穿缓存并请求到DB系统,DB系统从被击穿到缓存值被设置的这段时间会执行大量相同的查询,这些查询除了浪费系统资源外还会提升系统压力,为此,大部分业务会使用加锁来解决这个问题。

在加锁后,整个请求的流程就会变为先访问缓存层,在发现缓存层没有对应数据时(缓存失效),请求会先去请求锁,当请求到锁的请求才可以去DB系统查询,并在缓存系统设置缓存值,而获取不到锁的请求只能等待锁释放后从缓存系统中获取值并返回,如图:

通过图可以看到,在加锁后,访问DB系统的同类请求只剩一个了,这样一来可以减轻DB系统的压力,但是在采用加锁逻辑后会把压力从DB系统转移给了负责锁的系统,只是锁系统能容忍的上限会比DB系统高很多。
此外,如果这个锁系统是一个分布式锁,那么此时的锁系统也是一个热点值,后端服务与分布式锁系统之间会因为大量的请求获取锁而产生许多IO。

# 2.语言级别的解决方案

为了在解决缓存击穿的问题,同时减少缓存击穿时导致不同系统的IO交互次数变多的情况,新的解决方案必须是编程语言级别的,而不是一个单独的组件。同时,这个解决方案除了能兜住大量缓存击穿的请求外,还需要做到只让其中的第一个访问的请求能够命中DB系统获取值再返回且拿到的值又能跟其他请求共享。

由于这个解决方案会在多个请求之间共享值,所以我取名为Share​,它在系统架构中的位置如图:

通过图可以发现Share​的位置与锁一样,不过具体逻辑却会有不同,如果仔细研究它的逻辑,会发现它的逻辑与Asyncio.Future​类似。

比如在asyncio.Future​的使用过程中,不同的协程可以通过await asyncio.Future()​方法获取到已经被设置的结果,同时,如果这个值还没设置,其他协程在调用await asyncio.Future()​时会一直被阻塞,直到其他协程通过set_result​设置结果。

Note: 在下面介绍Share​中将以某个协程调用代替请求的操作

有了asyncio.Future​​后,Share​​的实现会变得很轻松,只要再实现如何放行第一个协程的执行即可。Share​​实现的第一步是定义一个类似于如下的数据结构:

Dict[str, asyncio.Future]

这个数据结构是一个Dict,其中它的key是这类协程的标识,然后再根据这个数据结构添加对应的逻辑:

  • 当协程通过Share​被调用时,根据key判断是否有同类协程
  • 如果没有则初始化一个asyncio.Future​,然后再执行这个协程,在协程执行完毕时把协程的返回值设置在asyncio.Future​中并从字典中删除这个key以及返回数据。
  • 如果有则调用await asyncio.Future​等待第一个共享协程的返回值。

具体代码如下:

import asyncio
from typing import Any, Dict, Callable

# 创建一个全局的字典,用于存储 Future 对象
future_dict: Dict[str, asyncio.Future] = {}

# 定义一个函数 share,接受一个标识符 key、一个函数 fn 和一个参数 param(可选)
async def share(key: str, fn: Callable, param: Any = None) -> Any:
    # 如果 key 不在 future_dict 中,则执行以下代码
    if key not in future_dict:
        try:
            # 创建一个 asyncio.Future 对象
            future = asyncio.Future()
            # 将该 Future 对象添加到 future_dict 中,以 key 作为标识符
            future_dict[key] = future
            # 调用传入的函数 fn,并等待其执行完成,将结果设置为 Future 的结果
            future.set_result(await fn(*(param or ())))
        finally:
            # 无论执行成功还是出现异常,都会在最后将 key 对应的 Future 从 future_dict 中移除
            future_dict.pop(key, None)
    else:
        # 如果 key 已经在 future_dict 中存在,则直接获取对应的 Future 对象
        future = future_dict[key]
  
    # 返回 Future 对象的结果
    return await future

# 定义一个异步函数 delay_return,接受一个整数参数 duration
async def delay_return(duration: int) -> int:
    # 等待指定的时间长度
    await asyncio.sleep(1)
    # 返回传入的参数作为结果
    return duration

# 定义一个异步函数 main,用于测试 share 函数是否能够正常运行
async def main() -> None:
    # 创建一个任务列表,包含了调用 share 函数的多个任务
    task_list = [share("demo", delay_return, (i, )) for i in range(10)]
    # 等待所有任务完成
    done, _ = await asyncio.wait(task_list)
    # 输出所有任务的结果
    print([future.result() for future in done])

# 运行主函数
asyncio.run(main())

在运行后可以发现,不同协程的初始化参数虽然是不同的,但是他们的结果是一样的(结果取决于哪个协程先运行),比如我这次运行后它的所有结果都为3,如下:

[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]

这个结果意味着语言级别的兜底逻辑的没问题的,但是它还有一些问题仍然需要解决。

在一开始时,我也是简单的实现了一个工具函数来解决缓存击穿的问题,但是在线上运行一段时间后,发现这个工具函数仍有一些小问题需要解决,于是对它进行了一些复杂化处理​,使其能够拓展并解决一些高并发的问题,同时也提升了易用性。

Share​整个实现分为两部分,第一部份是一个名为Token​的类,它的底层就是一个asyncio.Future​,而提供的方法都是基于asyncio.Future​的封装,代码如下:

import asyncio
from typing import Any, Optional, Union, TypeVar, Generic

_Tp = TypeVar('_Tp')

class Token(Generic[_Tp]):
    def __init__(self, key: Any):
        self._key: Any = key
        self._future: Optional[asyncio.Future[_Tp]] = None
  
    def can_do(self) -> bool:
        # 初始化 future 并判断是否执行后续操作
        # 这个逻辑可能有点怪,但是暂时没有想到更好的办法
        if not self._future:
            self._future = asyncio.Future()
            return True
        return False
  
    def is_done(self) -> bool:
        # 判断是否执行完成
        return self._future is not None and self._future.done()
  
    async def await_done(self) -> _Tp:
        # 获取设置在 future 的结果
        if not self._future:
            raise RuntimeError(f"You should use Token<{self._key}>.can_do() before Token<{self._key}>.await_done()")
        if not self._future.done():
            await self._future
        return self._future.result()
  
    def set_result(self, result: Union[_Tp, Exception]) -> bool:
        # 设置结果到 future 中,需要注意的是,如果是异常,需要通过 `set_exception` 设置异常,否则在设置异常后调用 `await asyncio.Future` 时不会抛出错误。
        if self._future and not self._future.done():
            if isinstance(result, Exception):
                self._future.set_exception(result)
            else:
                self._future.set_result(result)
            return True
        return False

而第二部分就是Share​​的主体部分了,代码如下:

from typing import Any, Callable, Coroutine, Dict, TypeVar

# 定义一个类型别名 ShareKeyType,用于标识 token 的键类型
_ShareKeyType = TypeVar('_ShareKeyType')

# 定义一个类型别名 P,用于参数的类型标注
P = TypeVar('P', tuple, None)

# 定义一个类型别名 R_T,用于函数返回值的类型标注
R_T = TypeVar('R_T')

class Share(object):
    def __init__(self) -> None:
        # 初始化存储 token 的容器
        self._token_dict: Dict[_ShareKeyType, Token] = dict()

    def _get_token(self, key: _ShareKeyType) -> Token:
        # 获取 token 的简单封装
        if key not in self._token_dict:
            self._token_dict[key] = Token(key)
        return self._token_dict[key]

    async def _do_handle(
        self,
        key: _ShareKeyType,
        func: Callable[P, Coroutine[Any, Any, R_T]],
        args: P = None,
        kwargs: P = None
    ) -> R_T:
        token: Token = self._get_token(key)
        try:
            # 判断是否可以执行操作
            if token.can_do():
                # 如果可以则执行
                try:
                    # 调用传入的函数 func,并将结果设置到 token 中
                    token.set_result(await func(*(args or ()), **(kwargs or {})))
                except Exception as e:
                    # 存储异常值到 token 中
                    token.set_result(e)
            # 通过 token 获取值并返回,没有值则会阻塞
            return await token.await_done()
        finally:
            # 使用完毕后删除 token
            self._token_dict.pop(key, None)

    def do(
        self,
        key: _ShareKeyType,
        func: Callable[P, Coroutine[Any, Any, R_T]],
        args: P = None,
        kwargs: P = None,
    ) -> Coroutine[Any, Any, R_T]:
        # 执行操作的入口,调用 _do_handle 处理函数调用和返回值
        return self._do_handle(key, func, args, kwargs)

通过代码可以发现Share​的主体逻辑非常简单,其中_do_handle​的逻辑与第二节中的share​函数类似,而新增的do​方法只是_do_handle​的一层封装,它在采用了PEP-612​的类型标注后,使用者可以方便的从编辑器知道do的返回类型,接下来通过一段代码来检查Share​是否正常,如下:

async def delay_return(duration: int) -> int:
    await asyncio.sleep(1)
    return durationasync def main() -> None:
    share = Share()
    task_list = [share.do("demo", delay_return, (i, ))for i in range(10)]
    done, _ = await asyncio.wait(task_list)
    print([future.result() for future in done])asyncio.run(main())

在运行代码后输出如下(值可能不同):

[3, 3, 3, 3, 3, 3, 3, 3, 3, 3] 

通过结果可以发现Share​运行正常,毕竟它的实现逻辑与share​函数类似,但是当把鼠标移动到task_list​上面可以发现,由于do​方法采用了PEP-612​的类型标注后,编辑器可以展示它的类型了,如下:

此外,基于_do_handle​可以开发出一个装饰器,这样用起来就非常方便了,使用方法如下:

async def main() -> None:
    share = Share()    @share.wrapper_do()
    async def delay_return(duration: int) -> int:
        await asyncio.sleep(1)
        return duration    task_list = [share.do("demo", delay_return, (i, ))for i in range(10)]
    done, _ = await asyncio.wait(task_list)
    print([future.result() for future in done])asyncio.run(main())

对应的实现方法见源码 (opens new window)

# 4.高并发下的问题

现在通过Share​可以解决缓存击穿的问题了,但是与其他中间层一样,在引入Share​之后会产生其他的严重的问题。

假设有这样一个场景,这个场景使用的DB系统有一个奇葩的Bug,这个Bug会导致每有1w次请求就有一个请求会被堵塞10秒,在未引入缓存击穿的保护逻辑之前,并不会有什么太大的影响,因为它的影响面很小,毕竟平均下来一个用户一天也就遇到几次,但是在引入缓存击穿保护的逻辑之后,就需要考虑这个问题对系统的影响了。
因为缓存击穿保护逻辑放行的请求在通过DB获取数据时,刚好遇到了Bug而堵塞了10秒,导致这个请求被堵住10秒后才能获取到值,这样会导致所有经过Share​的请求在10秒内都被堵住,而这时影响面就非常大了。
首先它会影响到使用这个功能的所有接口在这10秒内这些功能无法使用,其次是这些请求会占用文件描述符和内存等资源,在占用过多时会影响其它服务的使用进而造成服务雪崩,为此需要对Share​进行改进,防止单个请求异常而影响到其他地方的问题。

使用asyncio.wait​之类的带有超时异常机制的方法来执行也是可以的,因为Python Asyncio​的异常传递性,无论是asyncio.wait(share.do(xxx), timeout=xxx)​还是share.do(asyncio.wait(xxx, timeout=xxx))​,第一个被放行的协程在执行超时后抛出的异常会传递给其他协程。

# 4.1.放行指定比例的请求

目前的Share​只会允许第一个协程能被真正的执行,如果可以按照一定的几率放行请求,那么就能在防止请求堵住与降低DB压力之间做到一个平衡。具体的代码实现如下(只列出变更的方法):

class Share(object):
    def __init__(self, rate: Optional[Tuple[int, int]] = None) -> None:
        # 当rate = (1, 100)时代表是百分之一
        # 当rate = (1, 1000)时代表是千分之一
        if rate and rate[0] > rate[1]:
            raise ValueError(f"rate[0] should less than rate[1], but {rate[0]} > {rate[1]}")
        self._rate: Optional[Tuple[int, int]] = rate
        self._token_dict: Dict[_ShareKeyType, Token] = dict()        ...    async def _do_handle(
        self, key: _ShareKeyType, func: Callable[P, Coroutine[Any, Any, R_T]], args: P.args, kwargs: P.args
    ) -> R_T:
        token: Token = self._get_token(key)
        try:
            can_do = token.can_do()
            if not can_do and self._rate:
                can_do = random.randint(self._rate[0], self._rate[1]) == self._rate[0]
            if can_do:
                try:
                    # 多个请求也无所谓,Token会确保只有一个请求执行
                    token.set_result(await func(*(args or ()), **(kwargs or {})))
                except Exception as e:
                    token.set_result(e)
            return await token.await_done()
        finally:
            self._token_dict.pop(key, None)

代码中Share​在__init__​方法添加了一个新的参数rate​,并在_do_handle​方法中使用到,新的_do_handle​方法除了会放行第一个协程外,其他的协程会通过rate​来决定是否放行,具体的逻辑是调用者在通过Share​的_do_handle​执行协程时,_do_handle​在判断不允许放行后会使用random​模块根据rate​生成一个随机数,如果生成的随机数与rate[0]​相等时就会放行请求,现在改进第三节的测试代码以便验证rate​的效果,具体代码如下:

async def delay_return(duration: int) -> int:
    # 由于结果只有一个,所以需要打印出来才能判断是否有多个协程被放行
    print(f"I go it, {duration}")
    await asyncio.sleep(1)
    return durationasync def main() -> None:
    # 设置有1/3的放行概率
    share = Share(rate=(1, 3))
    task_list = [share.do("demo", delay_return, (i, ))for i in range(10)]
    done, _ = await asyncio.wait(task_list)
    print([future.result() for future in done])asyncio.run(main())

在运行程序后一般会看到有多条I go it, xxx​的文本输出,如下:

I go it, 2
I go it, 1
I go it, 6
I go it, 3
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

通过输出也可以看到经过Share​的处理后,所有协程获取到的结果还是取决于第一个协程执行的结果,但是确实有4个协程得到了执行了,这样一来即使有第个协程被堵住,其他协程也能够正常执行。

除了自动的按照一定比例放行协程的执行外,Share​还有两个方法可以手动放行协程的执行,调用者只需要自己根据业务场景在恰当的时间调用对应的方法也可以解决高并发下由于引入Share​而引发的问题。

比如每隔n秒钟执行一次。

# 4.2.取消被堵住的请求

Token​底层的asyncio.Future​拥有一个cancel​的方法,通过调用cancel​方法后不仅可以取消Future​还可以把取消异常传递给Future​对应的协程,进而中断协程的运行。

对于取消机制和asyncio.Future​可以参考:

于是可以通过这个方法使Share​拥有取消被堵住的请求功能, 具体的改进逻辑是先在Token​暴露出一个cancel​的方法,这个方法会尽最大的能力取消可以被取消的Future​:

class Token(Generic[_Tp]):
    ...    def cancel(self) -> bool:
        # 如果future可以被取消,则尽最大的努力取消future
        if self._future is not None and not self._future.done():
            if self._future.cancelled():
                self._future.cancel()
            else:
                self.set_result(asyncio.CancelledError())
            return True
        return False

不过Token​只是Future​的封装,调用者无法接触到Token​,所以需要在Share​添加一个cancel​方法使调用者可以通过这个方法取消因Share​影响的协程,从而释放资源占用,具体修改代码如下:

class Share(object):
    ...
    def cancel(self, key: Optional[_ShareKeyType] = None) -> None:
        if not key:
            # 如果key为空,则取消所有相关token
            for token in self._token_dict.values():
                token.cancel()
        else:
            # 不为空则按照Key取消指定的token
            self._token_dict[key].cancel()

修改完毕后编写一段代码进行验证:

import asyncio

# 定义一个协程来模拟延迟打印操作
async def delay_print(duration: int) -> int:
    await asyncio.sleep(1)
    return duration

# 定义一个用于取消共享的协程
async def cancel_in_aio(share: "Share") -> None:
    await asyncio.sleep(0.1)
    share.cancel()

# 定义一个Share类(代码片段中未提供实现)
class Share:
    def __init__(self):
        self.tokens = []  # 用于管理/取消的令牌或任务的占位符
  
    def cancel(self):
        # 取消与此共享关联的令牌或任务的逻辑
        pass

# 主函数用于协调异步任务
async def main() -> None:
    share = Share()  # 创建一个Share实例
    task_list: "List[Coroutine]" = [
        share.do("test_cancel_in_aio", delay_print, args=[i]) for i in [11, 12, 13, 14, 15, 16, 17, 18, 19]
    ]
    # 添加一个任务来取消与共享相关的所有令牌
    task_list.append(cancel_in_aio(share))
  
    t_list = [asyncio.create_task(t) for t in task_list]  # 为每个协程创建任务
    await asyncio.sleep(1)  # 等待一段时间
  
    result = []
    for t in t_list:
        # 跳过对'cancel_in_aio'协程的计数
        if t._coro.__name__ == "cancel_in_aio":  # type: ignore
            continue
        try:
            await t  # 等待每个任务完成
            result.append(1)  # 对成功完成的任务追加1
        except asyncio.CancelledError:
            result.append(0)  # 对被取消的任务追加0
  
    print(result)  # 打印结果

asyncio.run(main())  # 在asyncio事件循环中运行主函数

运行代码后可以看到输出结果如下,通过输出结果可以知道所有协程都被取消了:

1
[0, 0, 0, 0, 0, 0, 0, 0, 0]

# 4.3.忘记被堵住的请求

直接取消同一类请求也可能太狠了,它属于一个应急的方法,在业务场景中该操作可能导致多数用户在同一时间内都收到异常响应,为此Share​还引入一个forget​的功能,使Share​能忘掉当前托管的Token​,使后续的请求访问Share​时,Share​能够另起炉灶一个新的Token​,这样一来新的请求被会之前的Token​影响到。这个功能对应的改造很简单,只需要动到Share​类,如下:

class Share(object):
    ...
    def forget(self, key: _ShareKeyType) -> None:
        if key not in self._token_dict:
            raise KeyError(f"Token {key} not found")
        token = self._token_dict[key]
        if self._token_dict[key].is_done():
            raise RuntimeError(f"{token} is done")
        self._token_dict.pop(key, None)

接着在修改对应的老朋友–验证代码,如下:

async def delay_return(duration: int) -> int:
    await asyncio.sleep(1)
    return durationasync def main() -> None:
    share = Share()
    a_task = asyncio.Task(asyncio.wait([share.do("demo", delay_return, (i, ))for i in range(10)]))
    await asyncio.sleep(0.01)
    share.forget("demo")
    b_task = asyncio.Task(asyncio.wait([share.do("demo", delay_return, (i, ))for i in range(10, 20)]))
    await asyncio.sleep(0.1)
    print({future.result() for future in (await a_task)[0]})
    print({future.result() for future in (await b_task)[0]})asyncio.run(main())

这段代码会执行两批协程,第一批返回的值只有可能是0-9,而第二批的值只有可能是10-19,它们的运行间隔只有0.01秒,但是运行时长是一样的。
此外,在执行第二批之前会先调用share.forget("demo")​,使Share​忘记了自己托管过第一批协程,在运行代码后可以看到如下输出:

{4}
{15}

通过输出可以发现,第一批协程执行时间与第二批协程执行的时间虽然是一样的,但是他们共享的是不同的结果,Share​会正常的忘记掉第一批协程。

不过这个功能还是有点缺陷,假设第二批协程都能正常执行,但第一批协程还是因为被放行的协程在执行时被堵住而全都堵塞了,这是一种糟糕的情况。
大部分场景下都希望第二批协程执行完毕后,第一批协程也能共享到第一批协程的执行结果(被卡住的协程除外),于是需要对forget​的功能进行升级。
首先是Token​的改造,Token​需要在被forget​后又能在下个协程调用时重新被​起来,改造的代码如下:

class Token(Generic[_Tp]):    def __init__(self, key: Any):
        # 标记Token是否处于被忘记
        self.is_forget = False
        self._key: Any = key
        self._future: Optional[asyncio.Future[_Tp]] = None    def can_do(self) -> bool:
        if not self._future:
            self._future = asyncio.Future()
            return True
        if self.is_forget:
            # 如果该Token被忘记了,但是future还存在,那就重新记得Token,并放行该协程
            self.is_forget = False
            return True
        return False        ...

接着就是Share​的改造,主要是添加一个参数用于判断在调用forget​时是否为强制忘记​,代码如下:

class Share(object):
    def forget(self, key: _ShareKeyType, force: bool = True) -> None:
        # 添加一个参数用于是否强制忘记
        if key not in self._token_dict:
            raise KeyError(f"Token {key} not found")
        token = self._token_dict[key]
        if self._token_dict[key].is_done():
            raise RuntimeError(f"{token} is done")
        if force:
            # 如果是强制忘记则像之前一样移除Token
            self._token_dict.pop(key, None)
        else:
            # 不是强制忘记则只标记Token的属性为忘记,等待重新被记起来
            token.is_forget = True

接着运行如下测试代码:

_is_first: bool = True
async def delay_return(duration: int) -> int:
    global _is_first
    if _is_first:
        # 第一个执行的协程耗费的时间会比较久一点
        _is_first = False
        print(f"{duration} is first")
        await asyncio.sleep(3)
    else:
        await asyncio.sleep(1)
    return durationasync def main() -> None:
    share = Share()
    a_task = asyncio.Task(asyncio.wait([share.do("demo", delay_return, (i, ))for i in range(10)]))
    await asyncio.sleep(0.01)
    share.forget("demo", force=False)
    b_task = asyncio.Task(asyncio.wait([share.do("demo", delay_return, (i, ))for i in range(10, 20)]))
    await asyncio.sleep(0.1)
    # a_task会执行比较久,所以先打印b_task
    print({future.result() for future in (await b_task)[0]}, asyncio.get_event_loop().time())
    print({future.result() for future in (await a_task)[0]}, asyncio.get_event_loop().time())asyncio.run(main())

然后可以在终端中看到如下输出:

1
2
3
4 is first
{12} 291267.247258633
{12} 291269.238582831

通过输出可以发现,4是最先执行的,但是最后a​​和b​​任务的结果都是12(第二批的值),同时第一批执行完毕的时间是比第二批晚了3秒钟。

# 完整代码

import asyncio
import random
from functools import wraps
from typing import Any, Callable, Coroutine, Dict, Generic, Optional, Tuple, TypeVar, Union

from typing_extensions import ParamSpec

__all__ = ("Share", "Token")
_Tp = TypeVar("_Tp")


class Token(Generic[_Tp]):
    """Result and status of managed actions"""

    def __init__(self, key: Any):
        self.is_forget = False
        self._key: Any = key
        self._future: Optional[asyncio.Future[_Tp]] = None

    def can_do(self) -> bool:
        """Determine whether there is a future, if not, create a new future and return true, otherwise return false"""
        if not self._future:
            self._future = asyncio.Future()
            return True
        if self.is_forget:
            self.is_forget = False
            return True
        return False

    def is_done(self) -> bool:
        """Determine whether the execution is completed"""
        return self._future is not None and self._future.done()

    def cancel(self) -> bool:
        """Cancel the execution of the current action"""
        if self._future is not None and not self._future.done():
            if self._future.cancelled():
                self._future.cancel()
            else:
                self.set_result(asyncio.CancelledError())
            return True
        return False

    async def await_done(self) -> _Tp:
        """Wait for execution to end and return data"""
        if not self._future:
            raise RuntimeError(f"You should use Token<{self._key}>.can_do() before Token<{self._key}>.await_done()")
        if not self._future.done():
            await self._future
        return self._future.result()

    def set_result(self, result: Union[_Tp, Exception]) -> bool:
        """set data or exception"""
        if self._future and not self._future.done():
            if isinstance(result, Exception):
                self._future.set_exception(result)
            else:
                self._future.set_result(result)
            return True
        return False


_ShareKeyType = Union[Tuple[Any, ...], str]
P = ParamSpec("P")
R_T = TypeVar("R_T")


class Share(object):
    def __init__(self, rate: Optional[Tuple[int, int]] = None) -> None:
        if rate and rate[0] > rate[1]:
            raise ValueError(f"rate[0] should less than rate[1], but {rate[0]} > {rate[1]}")
        self._rate: Optional[Tuple[int, int]] = rate
        self._token_dict: Dict[_ShareKeyType, Token] = dict()

    def _get_token(self, key: _ShareKeyType) -> Token:
        """Get the token (if not, create a new one and return)"""
        if key not in self._token_dict:
            self._token_dict[key] = Token(key)
        return self._token_dict[key]

    def cancel(self, key: Optional[_ShareKeyType] = None) -> None:
        """Cancel the execution of the specified action, if the key is empty, cancel all"""
        if not key:
            for token in self._token_dict.values():
                token.cancel()
        else:
            self._token_dict[key].cancel()

    def forget(self, key: _ShareKeyType, force: bool = True) -> None:
        if key not in self._token_dict:
            raise KeyError(f"Token {key} not found")
        token = self._token_dict[key]
        if self._token_dict[key].is_done():
            raise RuntimeError(f"{token} is done")
        if force:
            self._token_dict.pop(key, None)
        else:
            token.is_forget = True

    async def _do_handle(
        self, key: _ShareKeyType, func: Callable[P, Coroutine[Any, Any, R_T]], args: P.args, kwargs: P.args
    ) -> R_T:
        token: Token = self._get_token(key)
        try:
            can_do = token.can_do()
            if not can_do and self._rate:
                can_do = random.randint(self._rate[0], self._rate[1]) == self._rate[0]
            if can_do:
                try:
                    # It doesn't matter if you have multiple requests,
                    # Token will ensure that only one request is executed
                    token.set_result(await func(*(args or ()), **(kwargs or {})))
                except Exception as e:
                    token.set_result(e)
            return await token.await_done()
        finally:
            self._token_dict.pop(key, None)

    def do(
        self,
        key: _ShareKeyType,
        func: Callable[P, Coroutine[Any, Any, R_T]],
        args: P.args = None,
        kwargs: P.kwargs = None,
    ) -> Coroutine[Any, Any, R_T]:
        return self._do_handle(key, func, args, kwargs)

    def wrapper_do(
        self, key: Optional[str] = None, only_include_class_param: bool = True, include_param: bool = False
    ) -> Callable:
        if only_include_class_param and include_param:
            raise ValueError("only_include_class_param and include_param can't be True at the same time")

        def wrapper(func: Callable[P, Coroutine[Any, Any, R_T]]) -> Callable[P, Coroutine[Any, Any, R_T]]:
            key_name: str = func.__qualname__ if key is None else key

            @wraps(func)
            async def wrapper_func(*args: P.args, **kwargs: P.kwargs) -> R_T:
                if include_param:
                    real_key: Tuple[Any, ...] = (key_name, tuple(args), tuple(kwargs.values()))
                else:
                    if only_include_class_param and args and args[0].__class__.__name__ in func.__qualname__:
                        real_key = (key_name + f":{id(args[0])}",)
                    else:
                        real_key = (key_name,)
                return await self._do_handle(real_key, func, args, kwargs)

            return wrapper_func

        return wrapper

    def __str__(self) -> str:
        return str(self._token_dict)